# Some synthetic search space objects to iterate quickly on

from search_spaces.base import SearchSpace
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
import numpy as np
from search_spaces.brax_env import Brax


class DummySearchSpace(SearchSpace):
    """Synthetic search space whose "training" is a closed-form function of the config vector.

    No environment is run: ``train_single`` scores a configuration with a fixed
    (seeded) linear-plus-quadratic function of its array encoding, which makes it
    cheap to exercise code that consumes a search space.
    """

    def __init__(self, offset=0.1, do_nas: bool = False, do_hpo: bool = True):
        """
        :param offset: value subtracted from every entry of the config vector before scoring
        :param do_nas: whether the config space contains the architecture (``NAS_``) hyperparameters
        :param do_hpo: whether the config space contains the training hyperparameters
        """
        super().__init__()
        self.config_space = self.get_configspace(do_nas, do_hpo)
        self.do_hpo = do_hpo
        self.do_nas = do_nas
        self.offset = offset
        self.max_parallel = 100
        self.env_name = 'dummy'

    def get_configspace(self, do_nas: bool = True,
                        do_hpo: bool = True,
                        choose_spectral_norm: bool = False,
                        use_same_arch_policy_q: bool = False):
        """
        Assemble the configuration space.

        :param do_nas: whether to do neural architecture search
        :param do_hpo: whether to do hyperparameter optimisation
        :param choose_spectral_norm: whether to select between the MLP and the MLP with spectral norm. If False,
            no spectral-norm hyperparameter is added.
        :param use_same_arch_policy_q: whether to constrain the policy and q networks to share one architecture;
            if True only the ``policy`` architecture hyperparameters are created.

        :return: the populated ``CS.ConfigurationSpace``
        """
        # NOTE: configspace seems buggy in testing some of the inactive values
        space = CS.ConfigurationSpace()

        if do_nas:
            # Every architecture hyperparameter carries the NAS_ prefix.
            if use_same_arch_policy_q:
                net_names = ['policy']
            else:
                net_names = ['policy', 'q']

            for net_name in net_names:
                if choose_spectral_norm:
                    # Plain MLP vs. MLP with spectral norm.
                    space.add_hyperparameter(
                        CSH.CategoricalHyperparameter(
                            f'NAS_{net_name}_use_spectral_norm',
                            choices=[True, False],
                            default_value=False,
                        )
                    )

                depth = CSH.UniformIntegerHyperparameter(
                    f'NAS_{net_name}_num_layers', lower=1, upper=8, default_value=2,
                )
                log2_width = CSH.UniformIntegerHyperparameter(
                    f'NAS_{net_name}_log2_width', lower=5, upper=10, default_value=5,
                )
                space.add_hyperparameter(depth)
                space.add_hyperparameter(log2_width)
                # Per-layer widths conditioned on the layer count, and a per-network
                # activation choice, were tried here previously and are disabled;
                # a single shared log2 width per network is used instead.

        if do_hpo:
            training_hps = [
                CSH.UniformFloatHyperparameter('log10_lr', lower=-4, upper=-3, default_value=np.log10(3e-4)),
                CSH.UniformFloatHyperparameter('discounting', lower=0.9, upper=0.9999, default_value=0.97),
                CSH.UniformFloatHyperparameter('log10_entropy_cost', lower=-6, upper=-1, default_value=-2),
                CSH.UniformIntegerHyperparameter('unroll_length', lower=5, upper=50, default_value=5),
                CSH.UniformIntegerHyperparameter('log2_batch_size', lower=6, upper=10, default_value=10),
                # A 'mini_batches_multiplier' entry is disabled; note that
                # n_mini_batches * batch_size % env_size must be 0.
                CSH.UniformIntegerHyperparameter('num_update_epochs', lower=2, upper=30, default_value=4),
                CSH.UniformFloatHyperparameter('reward_scaling', lower=0.05, upper=20, default_value=10, log=True),
                CSH.UniformFloatHyperparameter('lambda_', lower=0.9, upper=1, default_value=0.95),
                CSH.UniformFloatHyperparameter('ppo_epsilon', lower=0.01, upper=0.3, default_value=0.2),
            ]
            # 'normalize_observations' should always be true, so it is not exposed as a hyperparameter.
            space.add_hyperparameters(training_hps)

        return space

    def train_single(self, config, exp_idx=0, num_timesteps=0, **kwargs):
        """
        Score one configuration with a fixed synthetic objective.

        :param config: configuration exposing ``get_array()``
        :param exp_idx: accepted for interface compatibility; unused
        :param num_timesteps: accepted for interface compatibility; unused
        :return: dict with ``'x'`` = ``np.arange(100)`` and ``'y'`` = the (constant) score repeated 100 times
        """
        vector = config.get_array()
        n_dims = vector.shape[0]
        # Fixed seeds: the same coefficients are regenerated on every call, so the objective is deterministic.
        linear_w = np.random.RandomState(0).uniform(low=-1., high=1., size=n_dims)
        quadratic_w = np.random.RandomState(1).uniform(low=-1., high=1., size=n_dims)

        # nansum skips NaN entries of the vector.
        linear_term = np.nansum((vector - self.offset) * linear_w)
        quadratic_term = np.nansum(quadratic_w * (vector - self.offset) ** 2)
        score = linear_term + quadratic_term

        return {
            'x': np.arange(100),
            'y': score * np.ones(100),
        }

    def train_batch(self, configs, exp_idx_start=0, **kwargs):
        """
        Score several configurations one after another.

        :param configs: iterable of configurations
        :param exp_idx_start: accepted for interface compatibility; unused
        :return: list with one ``train_single`` result per configuration, in input order
        """
        trajectories = []
        for cfg in configs:
            trajectories.append(self.train_single(cfg, 0))
        return trajectories


class PB2MixerSyntheticProblem(SearchSpace):
    """
    Implementation of the simple 1-dim categorical, 1-dim continuous problem in the PB2 Mixer paper
    [1] J. Parker-Holder, V. Nguyen, S. Desai, and S. Roberts, “Tuning Mixed Input Hyperparameters on the Fly for
        Efficient Population Based AutoRL,” 2021.

    The objective is ``sin(x)`` or ``cos(x)`` plus small Gaussian noise, where the categorical input selects
    which of the two is used. The category-to-function assignment is swapped every ``time_varying_period``
    queries to create a step change over time.
    """

    def __init__(self, scale=1.0, time_varying_period=20, timestep_increment: int = int(1e5)):
        """
        :param scale: stored on the instance; not read by the methods of this class.
        :param time_varying_period: number of queries between swaps of the category-to-function assignment.
            The step change is there to see whether a time-varying GP surrogate is able to learn it.
            A value <= 0 disables the swap: category 0 then always maps to sin and category 1 to cos.
        :param timestep_increment: value reported as the single ``'x'`` entry of every returned trajectory.
        """
        super(PB2MixerSyntheticProblem, self).__init__()
        # define a configspace for this problem
        self.scale = scale
        self.time_varying_period = time_varying_period
        self.timestep_increment = timestep_increment
        self.config_space = self.get_configspace()
        # number of evaluations performed so far; drives the periodic swap in train_single
        self.n_queries = 0
        self.env_name = 'synthetic'

    def get_configspace(self):
        """
        :return: a ``CS.ConfigurationSpace`` with one binary categorical and one continuous input in [0, pi/2].
        """
        cs = CS.ConfigurationSpace()
        cs.add_hyperparameter(CSH.CategoricalHyperparameter('categorical', choices=[0, 1], default_value=0))
        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('continuous', 0., np.pi / 2., default_value=0.5))
        return cs

    def train_single(self, config, exp_idx=0, num_timesteps=0, **kwargs):
        """
        Evaluate the noisy objective once and advance the query counter by one.

        :param config: mapping with keys ``'categorical'`` and ``'continuous'``
        :param exp_idx: accepted for interface compatibility; unused
        :param num_timesteps: accepted for interface compatibility; unused
        :return: dict with single-element lists ``'x'`` (``timestep_increment``) and ``'y'`` (noisy objective)
        """
        h, x = config['categorical'], config['continuous']

        period = self.time_varying_period
        # Check the sign first so that a period of 0 means "no time variation" instead of
        # raising ZeroDivisionError in the floor division.
        sin_for_category_zero = period <= 0 or (self.n_queries // period) % 2 == 1
        if sin_for_category_zero:
            func = np.sin if h == 0 else np.cos
        else:
            func = np.cos if h == 0 else np.sin

        # A slowly varying multiplicative factor was tried here and is disabled.
        # NOTE(review): by operator precedence the factor scales only the noise term.
        time_factor = 1.

        traj = {
            'x': [self.timestep_increment],
            # function value plus N(0, 0.01) observation noise; no sign change is applied
            'y': [(func(x) + np.random.normal(0, 0.01) * time_factor)]
        }
        self.n_queries += 1
        return traj

    def train_batch(self, configs, exp_idx_start=0, **kwargs):
        """
        Evaluate several configurations one after another.

        :param configs: iterable of configurations
        :param exp_idx_start: accepted for interface compatibility; unused
        :return: list with one ``train_single`` result per configuration, in input order
        """
        # train_single already advances self.n_queries once per evaluation, so the counter
        # must not be incremented again here (doing so halved the effective period).
        return [self.train_single(c, 0) for c in configs]